{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<IBMBackend('ibm_kyiv')>\n"
     ]
    }
   ],
   "source": [
    "from qiskit.circuit import QuantumCircuit, QuantumRegister, ClassicalRegister\n",
    "import qiskit.circuit.library as qulib\n",
    "from qiskit.qasm3 import dump, dumps, Exporter\n",
    "import numpy as np\n",
    "import random\n",
    "from typing import List, Union\n",
    "from qiskit_ibm_runtime import QiskitRuntimeService\n",
    "from qiskit import generate_preset_pass_manager\n",
    "from qiskit_ibm_runtime import QiskitRuntimeService, SamplerV2 as Sampler\n",
    "from qmg.utils import MoleculeQuantumStateGenerator, ConditionalWeightsGenerator\n",
    "import pandas as pd\n",
    "from rdkit import RDLogger\n",
    "RDLogger.DisableLog('rdApp.*')\n",
    "\n",
    "def get_token(file_path):\n",
    "    with open(file_path) as f:\n",
    "        data = f.read()\n",
    "    token = data.strip()\n",
    "    return token\n",
    "\n",
    "my_token = get_token(\"./docs/ibmq_tokens.txt\")\n",
    "service = QiskitRuntimeService(channel=\"ibm_quantum\", token=my_token)\n",
    " \n",
    "backend = service.least_busy(simulator=False, operational=True)\n",
    "pm = generate_preset_pass_manager(backend=backend, optimization_level=1)\n",
    "print(backend)\n",
    "\n",
    "num_heavy_atom = 9\n",
    "num_parameters = int(8 + 3*(num_heavy_atom - 2) * (num_heavy_atom + 3) / 2)\n",
    "qubit_register_dict = {\"atom_1\": QuantumRegister(2, name=\"atom_1\"),\n",
    "                       \"atom_i\": QuantumRegister(2, name=\"atom_i\")}\n",
    "for i in range(1, num_heavy_atom):\n",
    "    qubit_register_dict.update({f\"bond_{i}\": QuantumRegister(2, name=f\"bond_{i}\")})\n",
    "\n",
    "clbit_register_dict = {}\n",
    "for i in range(1, num_heavy_atom+1):\n",
    "    clbit_register_dict.update({f\"atom_{i}_m\": ClassicalRegister(2, name=f\"atom_{i}_m\")})\n",
    "    for j in range(1, i):\n",
    "        clbit_register_dict.update({f\"bond_{i}_{j}_m\": ClassicalRegister(2, name=f\"bond_{i}_{j}_m\")})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get weight vector\n",
    "trial_df = pd.read_csv(\"results_chemistry_constraint_bo/unconditional_9_4.csv\")\n",
    "index = 501\n",
    "\n",
    "cwg = ConditionalWeightsGenerator(num_heavy_atom=num_heavy_atom, smarts=None, disable_connectivity_position=None)\n",
    "random_weight_vector = np.zeros(cwg.length_all_weight_vector)\n",
    "inputs = random_weight_vector\n",
    "number_flexible_parameters = len(random_weight_vector[cwg.parameters_indicator == 0.])\n",
    "partial_inputs = np.array(trial_df[trial_df[\"trial_index\"] == index][[f\"x{i+1}\" for i in range(number_flexible_parameters)]])[0]\n",
    "inputs[cwg.parameters_indicator == 0.] = partial_inputs\n",
    "inputs = cwg.apply_chemistry_constraint(inputs)\n",
    "weight_vector = inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.        , 0.93771391, 0.33959516, 0.87208743, 0.        ,\n",
       "       1.        , 0.        , 1.        , 0.92984925, 0.49968326,\n",
       "       1.        , 0.99330715, 0.00669285, 0.        , 0.5       ,\n",
       "       0.        , 1.        , 0.55639786, 1.        , 0.        ,\n",
       "       0.95335631, 0.03454501, 0.01209868, 0.        , 1.        ,\n",
       "       0.47865948, 0.5       , 0.        , 1.        , 0.61658179,\n",
       "       0.99936443, 0.46670468, 0.01108171, 0.00579138, 0.12360936,\n",
       "       0.85951755, 0.34363469, 0.52658671, 0.5       , 1.        ,\n",
       "       0.        , 0.93303569, 0.        , 1.        , 0.40874672,\n",
       "       1.        , 0.        , 0.49225377, 0.49225377, 0.00752682,\n",
       "       0.00331678, 0.00464886, 0.        , 0.5       , 0.        ,\n",
       "       0.78951158, 0.        , 1.        , 0.23592222, 1.        ,\n",
       "       0.43559468, 1.        , 0.36484118, 0.91858667, 0.04533939,\n",
       "       0.00331186, 0.49152386, 0.49152386, 0.00701669, 0.00331186,\n",
       "       0.00331186, 0.40438467, 0.62031226, 0.        , 0.92181379,\n",
       "       0.        , 0.5       , 0.35310907, 0.92719022, 0.05766601,\n",
       "       0.84058179, 0.        , 0.96460832, 0.62962534, 0.8172493 ,\n",
       "       0.58875864, 0.00411376, 0.00327224, 0.00327224, 0.48564367,\n",
       "       0.00476041, 0.48564367, 0.013294  , 0.5       , 0.690049  ,\n",
       "       0.        , 0.87897038, 0.46079699, 0.5       , 0.        ,\n",
       "       0.5       , 0.        , 0.87246189, 0.        , 1.        ,\n",
       "       0.5       , 0.5       , 0.59502753, 0.60525537, 1.        ,\n",
       "       0.01256196, 0.45408702, 0.00399529, 0.00600568, 0.05452145,\n",
       "       0.01063925, 0.45408702, 0.00410234, 0.30966776, 1.        ,\n",
       "       0.04723178, 1.        , 0.01546553, 0.5       , 0.        ,\n",
       "       1.        , 0.20819897, 1.        , 0.5       , 0.5       ,\n",
       "       0.        , 0.50411812, 0.5       , 0.5       ])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(qubit_register_dict.values())\n",
    "circuit = QuantumCircuit(*list(qubit_register_dict.values()), *list(clbit_register_dict.values()))\n",
    "\n",
    "def controlled_ry(control:int, target:int, digit:float):\n",
    "    circuit.cry(np.pi*digit, control, target)\n",
    "\n",
    "def reset_q_register(circuit, q_register, measures):\n",
    "    \"\"\"Reset the controlling qubits if they are in |1>.\"\"\"\n",
    "    with circuit.if_test((measures[0], True)):\n",
    "        circuit.x(q_register[0])\n",
    "    with circuit.if_test((measures[1], True)):\n",
    "        circuit.x(q_register[1])\n",
    "\n",
    "def reset_block(circuit, heavy_atom_idx):\n",
    "    reset_q_register(circuit, qubit_register_dict[\"atom_i\"],  clbit_register_dict[\"atom_2_m\"])\n",
    "    for k in range(1, heavy_atom_idx):\n",
    "        reset_q_register(circuit, qubit_register_dict[f\"bond_{k}\"],  clbit_register_dict[f\"bond_{heavy_atom_idx}_{k}_m\"])\n",
    "\n",
    "def measure_bond_block(circuit, heavy_atom_idx):\n",
    "    for k in range(1, heavy_atom_idx):\n",
    "        circuit.measure(qubit_register_dict[f\"bond_{k}\"],  clbit_register_dict[f\"bond_{heavy_atom_idx}_{k}_m\"])\n",
    "\n",
    "# first two atoms part\n",
    "circuit.ry(np.pi * weight_vector[0], 0)\n",
    "circuit.x(1)\n",
    "circuit.ry(np.pi * weight_vector[2], 2)\n",
    "circuit.ry(np.pi * weight_vector[4], 3)\n",
    "circuit.cx(0, 1)\n",
    "controlled_ry(1, 2, weight_vector[3])\n",
    "circuit.cx(2, 3)\n",
    "controlled_ry(0, 1, weight_vector[1])\n",
    "circuit.cx(1, 2)\n",
    "controlled_ry(2, 3, weight_vector[5])\n",
    "\n",
    "circuit.measure(qubit_register_dict[\"atom_1\"], clbit_register_dict[\"atom_1_m\"])\n",
    "circuit.measure(qubit_register_dict[\"atom_i\"], clbit_register_dict[\"atom_2_m\"])\n",
    "with circuit.if_test((clbit_register_dict[\"atom_2_m\"], 0)) as else_:\n",
    "    pass\n",
    "with else_:\n",
    "    circuit.ry(np.pi * weight_vector[6], 4)\n",
    "    circuit.x(5)\n",
    "    circuit.cx(4,5)\n",
    "    controlled_ry(4, 5, weight_vector[7])\n",
    "circuit.measure(qubit_register_dict[\"bond_1\"], clbit_register_dict[\"bond_2_1_m\"])\n",
    "\n",
    "used_part = 8\n",
    "heavy_atom_idx = 3\n",
    "# Third atom and recrusive part\n",
    "for heavy_atom_idx in range(3, num_heavy_atom+1):\n",
    "    # reset qubits for qubit reuse\n",
    "    reset_block(circuit, heavy_atom_idx - 1) # heavy atom idx starts with 2\n",
    "    with circuit.if_test((clbit_register_dict[f\"atom_{heavy_atom_idx-1}_m\"], 0)) as else_:\n",
    "        pass\n",
    "    with else_:\n",
    "        circuit.ry(np.pi * weight_vector[used_part], qubit_register_dict[\"atom_i\"][0])\n",
    "        circuit.ry(np.pi * weight_vector[used_part+1], qubit_register_dict[\"atom_i\"][1])\n",
    "        controlled_ry(qubit_register_dict[\"atom_i\"][0], qubit_register_dict[\"atom_i\"][1], weight_vector[used_part+2])\n",
    "    circuit.measure(qubit_register_dict[\"atom_i\"], clbit_register_dict[f\"atom_{heavy_atom_idx}_m\"])\n",
    "    with circuit.if_test((clbit_register_dict[f\"atom_{heavy_atom_idx}_m\"], 0)) as else_:\n",
    "        pass\n",
    "    with else_:\n",
    "        num_fixed = heavy_atom_idx-1\n",
    "        num_flexible = 2*num_fixed\n",
    "        bond_type_fixed_part = weight_vector[used_part+3: used_part+3+num_fixed]\n",
    "        bond_type_flexible_part = weight_vector[used_part+3+num_fixed: used_part+3+num_fixed+num_flexible]\n",
    "        for i in range(heavy_atom_idx-1):\n",
    "            circuit.ry(np.pi * bond_type_fixed_part[i], qubit_register_dict[f\"bond_{i+1}\"][1])\n",
    "            controlled_ry(qubit_register_dict[f\"bond_{i+1}\"][1], qubit_register_dict[f\"bond_{i+1}\"][0], bond_type_flexible_part[2*i]) # < 0.5\n",
    "            controlled_ry(qubit_register_dict[f\"bond_{i+1}\"][0], qubit_register_dict[f\"bond_{i+1}\"][1], bond_type_flexible_part[2*i+1]) # > 0.5\n",
    "\n",
    "    measure_bond_block(circuit, heavy_atom_idx)\n",
    "    used_part += 3*heavy_atom_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Job ID: cygq8tk9b62g00822s70\n"
     ]
    }
   ],
   "source": [
    "num_sample = 5000\n",
    "\n",
    "transpiled_qc = pm.run(circuit)\n",
    "sampler = Sampler(mode=backend)\n",
    "sampler.options.default_shots = num_sample\n",
    "job = sampler.run([transpiled_qc])\n",
    "\n",
    "print(f\">>> Job ID: {job.job_id()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No. of samples: 50000\n"
     ]
    }
   ],
   "source": [
    "from qiskit_ibm_runtime import QiskitRuntimeService\n",
    "\n",
    "def get_token(file_path):\n",
    "    with open(file_path) as f:\n",
    "        data = f.read()\n",
    "    token = data.strip()\n",
    "    return token\n",
    "\n",
    "service = QiskitRuntimeService(\n",
    "    channel='ibm_quantum',\n",
    "    instance='ibm-q/open/main',\n",
    "    token=get_token(\"./docs/ibmq_tokens.txt\")\n",
    ")\n",
    "job = service.job('cyhr3ex01rbg008fvxc0')\n",
    "job_result = job.result()\n",
    "num_sample = job_result[0].data.atom_1_m.num_shots\n",
    "print(\"No. of samples:\", num_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_process_results(job_result, num_heavy_atom):\n",
    "    num_shots = job_result[0].data.atom_1_m.num_shots\n",
    "    quantum_state_list = [\"\"]*num_shots\n",
    "    # node vector\n",
    "    for i in range(1, num_heavy_atom+1):\n",
    "        key = f\"atom_{i}_m\"\n",
    "        partial_results = job_result[0].data[key].get_bitstrings()\n",
    "        for z in range(num_shots):\n",
    "            quantum_state_list[z] = quantum_state_list[z] + partial_results[z][::-1]\n",
    "        \n",
    "        # bond adjacency\n",
    "        for j in range(1, i):\n",
    "            key = f\"bond_{i}_{j}_m\"\n",
    "            partial_results = job_result[0].data[key].get_bitstrings()\n",
    "            for z in range(num_shots):\n",
    "                quantum_state_list[z] = quantum_state_list[z] + partial_results[z][::-1]\n",
    "                # remove bond disconnection\n",
    "                if i - j == 1:\n",
    "                    if (quantum_state_list[z][-2*j-2:-2*j] != \"00\") and (quantum_state_list[z][-2*j:] == \"0\"*(2*j)):\n",
    "                        quantum_state_list[z] = quantum_state_list[z][:-1]+\"1\"\n",
    "    return quantum_state_list\n",
    "\n",
    "quantum_state_list = post_process_results(job_result, num_heavy_atom=9)\n",
    "mg = MoleculeQuantumStateGenerator(heavy_atom_size=9)\n",
    "quantum_state_list = [mg.post_process_quantum_state(qs, reverse=False) for qs in quantum_state_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validity:  0.1078\n",
      "Uniqueness:  0.0944\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "smiles_list = []\n",
    "for qs in quantum_state_list:\n",
    "    smiles = mg.QuantumStateToSmiles(qs)\n",
    "    if smiles:\n",
    "        smiles_list.append(smiles)\n",
    "smiles_dict = Counter(smiles_list)\n",
    "print(\"Validity: \", len(smiles_list) / num_sample)\n",
    "print(\"Uniqueness: \", len(smiles_dict) / num_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1, C, 52\n",
      "2, O, 12\n",
      "3, N, 4\n",
      "4, C[C@@H](O)[C@@H](CO)NO, 2\n",
      "5, C[C@@](O)(O)CCO, 2\n",
      "6, CC[C@@](C)(O)OC, 1\n",
      "7, CC[C@@H]1[C@@H](N)[C@@]1(O)O, 1\n",
      "8, CCN(O)[C@H](C)[C@@H](N)N, 1\n",
      "9, CO[C@]1(O)C[C@H]1C, 1\n",
      "10, C[C@]1(O)ON[C@@H]1CO, 1\n",
      "11, CC[C@@H](N)[C@@H](N)[C@@H](N)N, 1\n",
      "12, N[C@H](N(N)O)N1CCO1, 1\n",
      "13, C[C@H]1N[N@]2C[C@@]1(N)NN2, 1\n",
      "14, CN[C@H](NC)[C@H](N)CN, 1\n",
      "15, C[C@]1(N)C[C@@]1(O)O, 1\n",
      "16, CN1C[C@@H]([C@@H](O)O)C1, 1\n",
      "17, N=N[C@H](NN)NCN, 1\n",
      "18, CC[N@@]1N[C@]1(C)O, 1\n",
      "19, CC[C@H](NN)[C@@H](N)O, 1\n",
      "20, N[C@@H]1[C@@H](O)C[C@@]2(N)C[C@@H]12, 1\n"
     ]
    }
   ],
   "source": [
    "def print_top_k(counter_data, k=10):\n",
    "    top_k = counter_data.most_common(k)\n",
    "    for idx, (key, count) in enumerate(top_k, start=1):\n",
    "        print(f\"{idx}, {key}, {count}\")\n",
    "\n",
    "print_top_k(smiles_dict, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qmg-n",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
